{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.mamba_vit import MambaVit\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "from torchvision.transforms import ToTensor\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the classifier for 28x28 single-channel inputs with 10 output classes and move it to the GPU.\n",
    "# The class name has to match the name imported in the first cell (`MambaVit`); the earlier\n",
    "# spelling `MambaViT` was never bound, so this cell raised NameError on a fresh kernel.\n",
    "# NOTE(review): if src.mamba_vit actually exports `MambaViT`, correct the import line instead.\n",
    "m = MambaVit(image_size=28, patch_size=4, num_classes=10, channels=1, n_layer=8, dim=32, pool=\"mean\").to(\"cuda\")\n",
    "m = m.train()\n",
    "# Plain SGD with a fixed learning rate; CrossEntropyLoss takes raw logits and integer class targets.\n",
    "optimizer = torch.optim.SGD(m.parameters(), lr=0.0005)\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download (if not already present) both MNIST splits; images are converted to float tensors by ToTensor().\n",
    "mnist_train = torchvision.datasets.MNIST(\"\", download=True, train=True, transform=ToTensor())\n",
    "mnist_test = torchvision.datasets.MNIST(\"\", download=True, train=False, transform=ToTensor())\n",
    "# Training batches are reshuffled every epoch; evaluation order is kept fixed.\n",
    "train_dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=1024, shuffle=True)\n",
    "# Evaluate on the held-out split: this loader previously wrapped mnist_train, so the reported\n",
    "# validation loss and accuracy were measured on data the model had been trained on.\n",
    "test_dataloader = torch.utils.data.DataLoader(mnist_test, batch_size=1024, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 100 epochs; after each epoch evaluate on `test_dataloader` and print the mean\n",
    "# per-batch training loss, mean per-batch validation loss and overall validation accuracy.\n",
    "for epoch in range(100):\n",
    "    # ---- training pass ----\n",
    "    # The evaluation pass below switches the model to eval mode, so re-enable training mode each epoch.\n",
    "    m.train()\n",
    "    train_loss_list = []\n",
    "    for img, gt in tqdm(train_dataloader):\n",
    "        img = img.to(\"cuda\")\n",
    "        gt = gt.to(\"cuda\")\n",
    "        # Reset gradients left over from the previous step; without this, backward() keeps\n",
    "        # adding to them and every update uses the sum of all gradients seen so far.\n",
    "        optimizer.zero_grad()\n",
    "        pred = m(img)\n",
    "        loss = loss_fn(input=pred, target=gt)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Keep only the detached value so the list does not hold on to autograd state.\n",
    "        train_loss_list.append(loss.detach())\n",
    "    # ---- evaluation pass ----\n",
    "    # eval() puts mode-dependent layers (e.g. dropout, batch norm) into inference behaviour.\n",
    "    m.eval()\n",
    "    validation_loss_list = []\n",
    "    accuracy_list = []\n",
    "    # No gradients are needed for evaluation, so the forward pass and the loss both run under no_grad.\n",
    "    with torch.no_grad():\n",
    "        for img, gt in tqdm(test_dataloader):\n",
    "            img = img.to(\"cuda\")\n",
    "            gt = gt.to(\"cuda\")\n",
    "            pred = m(img)\n",
    "            loss = loss_fn(input=pred, target=gt)\n",
    "            validation_loss_list.append(loss.cpu())\n",
    "            # softmax is monotonic, so argmax over the raw logits selects the same class.\n",
    "            accuracy_list.append(pred.argmax(-1).cpu() == gt.cpu())\n",
    "    print(\"Training loss: \", torch.mean(torch.stack(train_loss_list)).item())\n",
    "    print(\"Validation loss: \", torch.mean(torch.stack(validation_loss_list)).item())\n",
    "    print(\"Validation accuracy: \", torch.mean(torch.cat(accuracy_list).float()).item())"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
